
import pandas as pd
import os
import logging
LOGGER = logging.getLogger(__name__)
from init import PATHS
import numpy as np


def collecting_ids(list_appln_id, list_docdb_id, test):
    """Expand seed ids to every application of the corresponding DOCDB families (TLS201).

    TLS201 is read in two passes. The first pass finds the ``docdb_family_id`` of every
    seed application. The second pass then collects every ``appln_id`` belonging to one
    of the families, using the now-complete family set.

    Growing the lists during a single pass would miss family members located in chunks
    read before their family was discovered. That made the result depend on the row
    order of the files.

    Parameters
    ----------
    list_appln_id : list-like
        Seed PATSTAT application ids. The caller's object is not modified.
    list_docdb_id : list-like
        Seed DOCDB family ids. The caller's object is not modified.
    test : bool
        If True, only the first chunk of each TLS201 file is read (in both passes).

    Returns
    -------
    tuple of (pandas.Series, pandas.Series)
        Unique application ids (Series named 'appln_ids') and unique family ids
        (Series named 'docdb_family_id'). The seeds are always included, even when
        they are absent from TLS201. Element order is not meaningful.
    """
    LOGGER.info('collecting_ids: Collect docdb_family_id from TLS201')
    LOGGER.info('       Initial number of appln_ids collected: {}'.format(len(list_appln_id)))
    LOGGER.info('       Initial number of docdb ids collected: {}'.format(len(list_docdb_id)))
    list_files = sorted(f for f in os.listdir(PATHS.patstatglobal) if 'tls201' in f)
    # Work on sets: the caller's lists are never mutated and duplicates collapse as we go
    appln_ids = set(list_appln_id)
    docdb_ids = set(list_docdb_id)
    # Pass 1: the family of each seed application
    for dfchunk in _iter_tls201_id_chunks(list_files, test):
        docdb_ids.update(dfchunk.loc[dfchunk['appln_id'].isin(appln_ids), 'docdb_family_id'])
    # Pass 2: every application of the (now complete) set of families
    for dfchunk in _iter_tls201_id_chunks(list_files, test):
        appln_ids.update(dfchunk.loc[dfchunk['docdb_family_id'].isin(docdb_ids), 'appln_id'])
    LOGGER.info('       New number of appln_ids collected: {}'.format(len(appln_ids)))
    LOGGER.info('       New number of docdb ids collected: {}'.format(len(docdb_ids)))
    return (pd.Series(list(appln_ids)).rename('appln_ids'),
            pd.Series(list(docdb_ids)).rename('docdb_family_id'))


def _iter_tls201_id_chunks(list_files, test):
    """Yield the (appln_id, docdb_family_id) chunks of the given TLS201 files, in order.

    With ``test`` True, only the first chunk of each file is yielded.
    """
    for file in list_files:
        LOGGER.info('       {}'.format(file))
        for dfchunk in pd.read_csv(PATHS.patstatglobal / file, chunksize=2000000, sep=',',
                                   usecols=['appln_id', 'docdb_family_id']):
            yield dfchunk
            if test:
                break


def collecting_info_on_applnids(list_appln_id, list_docdb_id, namefile, test, small=False):
    """Save every TLS201 row whose appln_id or docdb_family_id is in the given id lists.

    Parameters
    ----------
    list_appln_id : list-like
        Application ids to keep.
    list_docdb_id : list-like
        DOCDB family ids to keep. A row is kept if either its appln_id or its
        docdb_family_id matches.
    namefile : str
        Suffix of the output file
        ``PATHS.dropbox / 'Data_outputted/C_PatentVariables/Patents_of_{namefile}.csv'``.
    test : bool
        If True, only the first chunk of each TLS201 file is read.
    small : bool, default False
        If True, keep only the fixed subset of TLS201 columns listed below;
        otherwise keep every column.

    Returns
    -------
    None
        The result is written to CSV.
    """
    if small:
        col_tokeep = ['appln_id', 'appln_auth', 'docdb_family_id', 'granted', 'nb_citing_docdb_fam',
                      'earliest_filing_date', 'earliest_filing_year', 'appln_filing_date', 'docdb_family_size',
                      'nb_applicants', 'nb_inventors', 'earliest_filing_id', 'inpadoc_family_id']
    else:
        col_tokeep = None  # read all columns
    LOGGER.info('Collect All Variables From TLS201')
    # Sorted (as in collecting_ids) so that the row order of the output file does not
    # depend on the arbitrary order returned by os.listdir
    list_files = sorted(f for f in os.listdir(PATHS.patstatglobal) if 'tls201' in f)
    df_appinfo = []
    for file in list_files:
        LOGGER.info('       {}'.format(file))
        for dfchunk in pd.read_csv(PATHS.patstatglobal / file, chunksize=500000, sep=',',
                                   low_memory=False, usecols=col_tokeep):
            mask1 = dfchunk['appln_id'].isin(list_appln_id)
            mask2 = dfchunk['docdb_family_id'].isin(list_docdb_id)
            df_appinfo.append(dfchunk[mask1 | mask2])
            if test:
                break
    df_appinfo = pd.concat(df_appinfo)
    LOGGER.info('       Number of patstat id for which info was collected: {}'.format(df_appinfo['appln_id'].nunique()))
    LOGGER.info('       Number of docdb_family_id for which info was collected: {}'.format(df_appinfo['docdb_family_id'].nunique()))
    df_appinfo.to_csv(PATHS.dropbox / 'Data_outputted/C_PatentVariables/Patents_of_{}.csv'.format(namefile), index=False)
    return None


def extract_and_save_cpc_ipc_codes(namefile, test):
    """Write Families_of_{namefile}_cpc_ipc.csv: CPC/IPC codes and energy classes per family.

    For every docdb_family_id of Patents_of_{namefile}.csv, the output holds the
    following columns:

    - 'CPC' and 'IPC': all codes found in PATSTAT for the family's applications
      (energy-related or not), as one sorted, comma-separated string.
    - 'Type', 'Sector', 'Sub-sector', 'BatTrans': the classification derived from
      CPC_IPC_codes_allsubgroups.csv. Each is a sorted, comma-separated string of
      distinct values, e.g. "batteries,fuel cells".

    Each column is aggregated on its own, so the i-th 'Type' need not belong to the
    i-th 'Sub-sector'. A family without any value gets an empty string, which to_csv
    writes as an empty (missing) field.
    """
    def _join_distinct_strings(values):
        # Sorted, comma-separated distinct strings of one group; NaN and non-str values are skipped
        return ','.join(sorted(v for v in values.unique() if type(v) is str))

    def _aggregate_by_family(families, per_appln):
        # Attach each application's rows to its family, then collapse every column per family
        combined = (families
                    .merge(per_appln, how='left', on='appln_id')
                    .drop(columns=['appln_id'])
                    .drop_duplicates())
        return combined.groupby(['docdb_family_id'], as_index=False).agg(_join_distinct_strings)

    families = pd.read_csv(PATHS.dropbox / 'Data_outputted/C_PatentVariables/Patents_of_{}.csv'.format(namefile),
                           usecols=['docdb_family_id', 'appln_id'])
    code_list = pd.read_csv(PATHS.dropbox / 'Data_outputted/C_PatentVariables/CPC_IPC_codes_allsubgroups.csv')
    # Per scheme: (every code found per appln_id, energy classification per appln_id)
    cpc_codes, cpc_classes = collect_and_classify_codes(families, 'cpc', code_list, test)
    ipc_codes, ipc_classes = collect_and_classify_codes(families, 'ipc', code_list, test)
    by_family_cpc = _aggregate_by_family(families, cpc_codes)
    by_family_ipc = _aggregate_by_family(families, ipc_codes)
    # CPC- and IPC-based classifications are pooled before aggregating per family
    by_family_class = _aggregate_by_family(families, pd.concat([cpc_classes, ipc_classes]).drop_duplicates())
    result = by_family_cpc.merge(by_family_ipc, on='docdb_family_id', how='left')
    result = result.merge(by_family_class, on='docdb_family_id', how='left')
    result.to_csv(PATHS.dropbox / 'Data_outputted/C_PatentVariables/Families_of_{}_cpc_ipc.csv'.format(namefile),
                  header=True, index=False)
    return


def collect_and_classify_codes(df_appinfo, scheme, df_greendirty_codes, test):
    """Collect the CPC or IPC codes of the applications in ``df_appinfo`` and classify them.

    Parameters
    ----------
    df_appinfo : pandas.DataFrame
        Must contain an 'appln_id' column; only these applications are kept.
    scheme : str
        'cpc' (PATSTAT table tls224) or 'ipc' (PATSTAT table tls209).
    df_greendirty_codes : pandas.DataFrame
        Energy code list, passed on to classify_cpc_ipc_codes.
    test : bool
        If True, only the first chunk of each table file is read.

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.DataFrame)
        The first frame holds every (appln_id, code) pair found. The code column is
        named 'CPC' or 'IPC', with spaces removed.
        The second frame holds the deduplicated classification of those applications,
        with columns 'appln_id', 'Type', 'Sector', 'Sub-sector', 'BatTrans'.
    """
    # scheme = 'cpc' or  scheme = 'ipc'
    LOGGER.info('       Collecting {} codes of the patents associated to OEMs'.format(scheme.upper()))
    # Match the full table name, as the other readers in this module do ('tls201', 'tls229', ...);
    # a bare '224' / '209' substring could also match unrelated file names.
    name_PATSTAT_table = {'cpc': 'tls224', 'ipc': 'tls209'}
    list_files = sorted(f for f in os.listdir(PATHS.patstatglobal) if name_PATSTAT_table[scheme] in f)
    symbol_col = '{}_class_symbol'.format(scheme)
    df_Patents = []
    for file in list_files:
        LOGGER.info('       File: {}'.format(file))
        for dfchunk in pd.read_csv(PATHS.patstatglobal / file, chunksize=500000,
                                   usecols=['appln_id', symbol_col], sep=','):
            # Keep only our applications first, so the string cleaning runs on matching rows only
            dfchunk = dfchunk[dfchunk['appln_id'].isin(df_appinfo['appln_id'])]
            # Remove the spaces inside PATSTAT symbols to make them comparable to our code list.
            # assign() returns a new frame, avoiding chained assignment on the filtered slice.
            dfchunk = dfchunk.assign(**{symbol_col: dfchunk[symbol_col].str.replace(' ', '', regex=False)})
            df_Patents.append(dfchunk.rename(columns={symbol_col: scheme.upper()}))
            if test:
                break
    df_Patents = pd.concat(df_Patents)
    df_Patents_type = classify_cpc_ipc_codes(df_Patents, df_greendirty_codes, scheme)
    df_Patents_type = df_Patents_type[['appln_id', 'Type', 'Sector', 'Sub-sector', 'BatTrans']].drop_duplicates()
    return df_Patents, df_Patents_type


def classify_cpc_ipc_codes(df_Patents, df_greendirty_codes, scheme):
    """Attach the energy classification of our code list to PATSTAT (appln_id, code) pairs.

    Parameters
    ----------
    df_Patents : pandas.DataFrame
        Columns 'appln_id' and the code column named ``scheme.upper()`` ('CPC' or 'IPC').
    df_greendirty_codes : pandas.DataFrame
        Energy code list; not modified. Columns read: 'Sector', 'Sub-sector', 'Type',
        '<SCHEME>class', '<SCHEME>subclass' and '<SCHEME>full'. For IPC, 'Code' is
        also read.
    scheme : str
        'cpc' or 'ipc'.

    Returns
    -------
    pandas.DataFrame
        The rows of ``df_Patents`` that matched a classified code (non-null 'Sub-sector').
        Columns are 'appln_id', '<SCHEME>', '<SCHEME>full', 'Type', 'Sector',
        'Sub-sector' and 'BatTrans'. A pair matching several code-list rows yields
        several rows.
    """
    prefix = scheme.upper()
    full_col = '{}full'.format(prefix)
    # Codes with Electricity Transport are grey wrt elec and dirty wrt transpo. No need to have
    # both categorizations: keep them only as dirty transport. This filtered list is used for
    # both the full-code match and the IPC subclass-level match below.
    eligible = df_greendirty_codes[df_greendirty_codes['Sector'] != 'Electricity Transport']
    # Explicit copy: adding a column to a filtered slice raises SettingWithCopyWarning
    codes = eligible.copy()
    # BatTrans: 'Trans' for battery codes in class B60 or subclass Y02T, 'NonTrans' for the
    # other battery codes, missing otherwise. The column is object dtype from the start so
    # writing strings into it does not upcast a float column (deprecated in pandas 2.x).
    is_battery = codes['Sub-sector'] == 'batteries'
    is_transport = ((codes['{}class'.format(prefix)] == 'B60')
                    | (codes['{}subclass'.format(prefix)] == 'Y02T'))
    codes['BatTrans'] = pd.Series(np.nan, index=codes.index, dtype=object)
    codes.loc[is_battery & is_transport, 'BatTrans'] = 'Trans'
    codes.loc[is_battery & ~is_transport, 'BatTrans'] = 'NonTrans'
    codes = codes[[full_col, 'Type', 'Sector', 'Sub-sector', 'BatTrans']].drop_duplicates()
    df_Patents = df_Patents.merge(codes, how='left', left_on=prefix, right_on=full_col)
    if scheme == 'ipc':
        # The IPC table in PATSTAT lists some codes at the subclass level (4-character codes),
        # so subclass-level entries of the code list are matched too: e.g. a patent listed with
        # the IPC code "C10J" in PATSTAT should be captured.
        # NB (original author): all codes in the PATSTAT CPC and IPC tables are at the subgroup
        # level except a few IPC codes at the subclass level.
        mask1 = eligible['Code'] == eligible['IPCsubclass']
        mask2 = eligible['Code'] == eligible['IPCclass']
        codes4digit = eligible[mask1 | mask2][['IPCsubclass', 'Type', 'Sector', 'Sub-sector']].drop_duplicates()
        codes4digit = codes4digit.rename(columns={'Type': 'Type_4dg', 'Sector': 'Sector_4dg',
                                                  'Sub-sector': 'Sub-sector_4dg'})
        df_Patents = df_Patents.merge(codes4digit, how='left', left_on=prefix, right_on='IPCsubclass')
        # Full-code matches take precedence; subclass-level matches fill the gaps
        for col in ['Type', 'Sector', 'Sub-sector']:
            df_Patents[col] = df_Patents[col].fillna(df_Patents[col + '_4dg'])
        df_Patents = df_Patents[['appln_id', 'IPC', 'IPCfull', 'Type', 'Sector', 'Sub-sector', 'BatTrans']]
    df_Patents = df_Patents[df_Patents['Sub-sector'].notnull()]
    return df_Patents


def extract_and_save_nace_codes(namefile, test):
    """Write Families_of_{namefile}_nace.csv: the NACE2 codes (TLS229) of each DOCDB family.

    For every docdb_family_id of Patents_of_{namefile}.csv, the columns 'nace2_code',
    'weight' and 'nace2_descr' each hold a sorted, comma-separated string. The
    descriptions come from tls902_part01.csv. Codes and weights are formatted with
    two decimals.

    Applications without a TLS229 row keep missing values, which are skipped by the
    aggregation. They therefore never appear as the literal text 'nan'. A family
    without any NACE entry gets empty strings.

    NOTE(review): each column is sorted independently, so the i-th weight does not
    necessarily belong to the i-th code.
    """
    def _two_decimals(value):
        # '{:,.2f}' formatting, but missing values stay missing: formatting NaN would
        # produce the string 'nan', which the str-only aggregation below would keep.
        return '{:,.2f}'.format(value) if pd.notna(value) else value

    df_appinfo = pd.read_csv(PATHS.dropbox / 'Data_outputted/C_PatentVariables/Patents_of_{}.csv'.format(namefile),
                             usecols=['docdb_family_id', 'appln_id'])
    df_appinfo = collect_nace_codes(df_appinfo, test)
    # The total of all weights of an application always equals 1
    df_appinfo = df_appinfo.drop(columns=['appln_id']).drop_duplicates()
    # astype(object): keeps the merge key an object column even when every value is missing
    df_appinfo['weight'] = df_appinfo['weight'].map(_two_decimals).astype(object)
    df_appinfo['nace2_code'] = df_appinfo['nace2_code'].map(_two_decimals).astype(object)
    NACEs = pd.read_csv(PATHS.patstatglobal / 'tls902_part01.csv', sep=',',
                        usecols=['nace2_code', 'nace2_descr']).drop_duplicates()
    # Drop rows without description or without code; pandas merges missing keys with
    # missing keys, so a code-less row would attach to applications lacking a NACE code.
    NACEs = NACEs[NACEs['nace2_descr'].notnull() & NACEs['nace2_code'].notnull()].copy()
    NACEs['nace2_code'] = NACEs['nace2_code'].astype(float).map(_two_decimals)
    df_appinfo = df_appinfo.merge(NACEs, on='nace2_code', how='left')
    df_appinfo = df_appinfo.groupby(['docdb_family_id'], as_index=False).agg(
        lambda x: ','.join(sorted(i for i in x if isinstance(i, str))))
    df_appinfo.to_csv(PATHS.dropbox / 'Data_outputted/C_PatentVariables/Families_of_{}_nace.csv'.format(namefile),
                      index=False)
    return


def collect_nace_codes(df_appinfo, test):
    """Left-join the TLS229 rows of every application listed in ``df_appinfo``.

    Parameters
    ----------
    df_appinfo : pandas.DataFrame
        Must contain an 'appln_id' column.
    test : bool
        If True, only the first chunk of each TLS229 file is read.

    Returns
    -------
    pandas.DataFrame
        ``df_appinfo`` merged (how='left', on 'appln_id') with all TLS229 columns.
        Applications without a TLS229 row get missing values there.
    """
    LOGGER.info('Collect nace codes from TLS229')
    wanted_ids = df_appinfo['appln_id']
    matched_rows = []
    tls229_files = [name for name in os.listdir(PATHS.patstatglobal) if 'tls229' in name]
    for name in tls229_files:
        LOGGER.info('       {}'.format(name))
        reader = pd.read_csv(PATHS.patstatglobal / name, chunksize=500000, sep=',', low_memory=False)
        for chunk in reader:
            matched_rows.append(chunk.loc[chunk['appln_id'].isin(wanted_ids)])
            if test:
                break
    return df_appinfo.merge(pd.concat(matched_rows), on='appln_id', how='left')


def extract_and_save_psnsector(list_appln_id, namefile, test):
    """Write Families_of_{namefile}_psnsector.csv: the applicant sectors of each DOCDB family.

    Applicants of ``list_appln_id`` are taken from TLS207 and their ``psn_sector`` from
    TLS206. For every docdb_family_id of Patents_of_{namefile}.csv, the output column
    'psn_sector' is a sorted, comma-separated string of the distinct sectors found.
    It is empty if no sector was found.
    """
    LOGGER.info('extract_and_save_psnsector')
    applicants = collecting_applnid_of_applicants(list_appln_id, test)
    applicants = adding_psn_sector(applicants, test)
    families = pd.read_csv(PATHS.dropbox / 'Data_outputted/C_PatentVariables/Patents_of_{}.csv'.format(namefile),
                           usecols=['docdb_family_id', 'appln_id'])
    # The id columns are not part of the output: drop them before aggregating per family
    families = families.merge(applicants, on='appln_id', how='left').drop(columns=['appln_id', 'person_id'])
    families = families.groupby(['docdb_family_id'], as_index=False).agg(
        lambda sectors: ','.join(sorted(s for s in sectors.unique() if type(s) is str)))
    families.to_csv(PATHS.dropbox / 'Data_outputted/C_PatentVariables/Families_of_{}_psnsector.csv'.format(namefile),
                    index=False)
    return


def collecting_applnid_of_applicants(list_appln_id, test):
    """Return the person_id of every applicant of the given applications, from TLS207.

    Parameters
    ----------
    list_appln_id : list-like
        Application ids to look up.
    test : bool
        If True, only the first chunk of each TLS207 file is read.

    Returns
    -------
    pandas.DataFrame
        Columns 'person_id' and 'appln_id', one row per TLS207 row of those
        applications with ``applt_seq_nr > 0``.
    """
    LOGGER.info('       Collecting all person_id of applicants: looping through tls207')
    # Read every TLS207 part file, like the other TLS readers in this module, so that
    # applicants stored in parts other than tls207_part01.csv (if any) are not missed.
    list_files = sorted(f for f in os.listdir(PATHS.patstatglobal) if 'tls207' in f)
    df_personid = []
    for file in list_files:
        LOGGER.info('       {}'.format(file))
        for dfchunk in pd.read_csv(PATHS.patstatglobal / file, chunksize=2000000, sep=',',
                                   usecols=['person_id', 'appln_id', 'applt_seq_nr']):
            mask1 = dfchunk['appln_id'].isin(list_appln_id)
            mask2 = dfchunk['applt_seq_nr'] > 0   # will be true only for applicants
            df_personid.append(dfchunk[mask1 & mask2])
            if test:
                break
    df_personid = pd.concat(df_personid)
    return df_personid.drop(columns=['applt_seq_nr'])


def adding_psn_sector(df_personid, test):
    """Left-join the TLS206 ``psn_sector`` of each person onto ``df_personid``.

    Parameters
    ----------
    df_personid : pandas.DataFrame
        Must contain a 'person_id' column.
    test : bool
        If True, only the first chunk of each TLS206 file is read.

    Returns
    -------
    pandas.DataFrame
        ``df_personid`` with a 'psn_sector' column added (how='left', on 'person_id').
    """
    LOGGER.info('       Collecting person info: looping through tls206')
    list_files = sorted(f for f in os.listdir(PATHS.patstatglobal) if 'tls206' in f)
    # Loop-invariant: the persons to look up are computed once rather than for every chunk
    person_ids = df_personid['person_id'].drop_duplicates()
    df_psn_sector = []
    coltokeep = ['person_id', 'psn_sector']
    for file in list_files:
        LOGGER.info('       {}'.format(file))
        for dfchunk in pd.read_csv(PATHS.patstatglobal / file, chunksize=500000, sep=',', usecols=coltokeep):
            df_psn_sector.append(dfchunk[dfchunk['person_id'].isin(person_ids)])
            if test:
                break
    df_psn_sector = pd.concat(df_psn_sector)
    return df_personid.merge(df_psn_sector, on='person_id', how='left')



